{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 7.  ,  0.27,  0.36, ...,  0.45,  8.8 ,  6.  ],\n",
       "       [ 6.3 ,  0.3 ,  0.34, ...,  0.49,  9.5 ,  6.  ],\n",
       "       [ 8.1 ,  0.28,  0.4 , ...,  0.44, 10.1 ,  6.  ],\n",
       "       ...,\n",
       "       [ 6.5 ,  0.24,  0.19, ...,  0.46,  9.4 ,  6.  ],\n",
       "       [ 5.5 ,  0.29,  0.3 , ...,  0.38, 12.8 ,  7.  ],\n",
       "       [ 6.  ,  0.21,  0.38, ...,  0.32, 11.8 ,  6.  ]], dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tabular data can be loaded in three ways:\n",
    "# the csv module, NumPy, or pandas.\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import csv\n",
    "\n",
    "wine_path = '../data/p1ch4/tabular-wine/winequality-white.csv'\n",
    "\n",
    "# The file is ';'-delimited; skiprows=1 drops the header row of column names\n",
    "# so that every remaining field parses as a 32-bit float.\n",
    "wineq_numpy = np.loadtxt(\n",
    "    wine_path,\n",
    "    dtype=np.float32,\n",
    "    delimiter=';',\n",
    "    skiprows=1,\n",
    ")\n",
    "wineq_numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4898, 12)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read only the first row of the CSV, which holds the column names.\n",
    "# A context manager guarantees the file handle is closed as soon as the\n",
    "# header has been read; newline='' is the csv module's recommended open mode.\n",
    "with open(wine_path, newline='') as csv_file:\n",
    "    col_list = next(csv.reader(csv_file, delimiter=';'))\n",
    "\n",
    "wineq_numpy.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['fixed acidity',\n",
       " 'volatile acidity',\n",
       " 'citric acid',\n",
       " 'residual sugar',\n",
       " 'chlorides',\n",
       " 'free sulfur dioxide',\n",
       " 'total sulfur dioxide',\n",
       " 'density',\n",
       " 'pH',\n",
       " 'sulphates',\n",
       " 'alcohol',\n",
       " 'quality']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the column names parsed from the CSV header row.\n",
    "col_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898, 12]), torch.float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrap the NumPy array as a tensor; from_numpy shares memory with\n",
    "# wineq_numpy rather than copying it.\n",
    "wineq = torch.from_numpy(wineq_numpy)\n",
    "wineq.shape, wineq.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 7.0000,  0.2700,  0.3600,  ...,  3.0000,  0.4500,  8.8000],\n",
       "         [ 6.3000,  0.3000,  0.3400,  ...,  3.3000,  0.4900,  9.5000],\n",
       "         [ 8.1000,  0.2800,  0.4000,  ...,  3.2600,  0.4400, 10.1000],\n",
       "         ...,\n",
       "         [ 6.5000,  0.2400,  0.1900,  ...,  2.9900,  0.4600,  9.4000],\n",
       "         [ 5.5000,  0.2900,  0.3000,  ...,  3.3400,  0.3800, 12.8000],\n",
       "         [ 6.0000,  0.2100,  0.3800,  ...,  3.2600,  0.3200, 11.8000]]),\n",
       " torch.Size([4898, 11]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Features: every column except the last one (the quality score).\n",
    "data = wineq[:, :-1]\n",
    "data, data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([6., 6., 6.,  ..., 6., 7., 6.]), torch.Size([4898]))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Target: the last column (quality score), still stored as float32 here.\n",
    "target = wineq[:, -1]\n",
    "target, target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6, 6, 6,  ..., 6, 7, 6])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Treat the quality scores as integer labels: cast the float column to int64.\n",
    "\n",
    "target = wineq[:, -1].long()\n",
    "target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 1., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode the integer quality scores.\n",
    "#\n",
    "# scatter_ arguments:\n",
    "#   1. dim=1: the index values select a position along dimension 1 (the\n",
    "#      columns) of target_onehot.\n",
    "#   2. index: target is 1-D (one score per sample) while target_onehot is 2-D\n",
    "#      (samples x 10), so unsqueeze(1) adds a trailing dimension to give the\n",
    "#      index tensor one column per row.\n",
    "#   3. value: the number written at each selected position (1.0).\n",
    "#\n",
    "# Result: in every row, the entry whose column index equals that sample's\n",
    "# score is set to 1 (e.g. a score of 8 sets column index 8, counting from 0).\n",
    "\n",
    "target_onehot = torch.zeros(target.shape[0], 10)\n",
    "target_onehot.scatter_(1, target.unsqueeze(1), 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6.8548e+00, 2.7824e-01, 3.3419e-01, 6.3914e+00, 4.5772e-02, 3.5308e+01,\n",
       "        1.3836e+02, 9.9403e-01, 3.1883e+00, 4.8985e-01, 1.0514e+01])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-column mean: reducing over dim 0 (the samples) leaves one value per feature.\n",
    "data_mean = data.mean(dim=0)\n",
    "data_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7.1211e-01, 1.0160e-02, 1.4646e-02, 2.5726e+01, 4.7733e-04, 2.8924e+02,\n",
       "        1.8061e+03, 8.9455e-06, 2.2801e-02, 1.3025e-02, 1.5144e+00])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-column variance over the samples (dim 0), one value per feature.\n",
    "data_var = data.var(dim=0)\n",
    "data_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.7209e-01, -8.1764e-02,  2.1325e-01,  ..., -1.2468e+00,\n",
       "         -3.4914e-01, -1.3930e+00],\n",
       "        [-6.5743e-01,  2.1587e-01,  4.7991e-02,  ...,  7.3992e-01,\n",
       "          1.3467e-03, -8.2418e-01],\n",
       "        [ 1.4756e+00,  1.7448e-02,  5.4378e-01,  ...,  4.7502e-01,\n",
       "         -4.3677e-01, -3.3662e-01],\n",
       "        ...,\n",
       "        [-4.2042e-01, -3.7940e-01, -1.1915e+00,  ..., -1.3131e+00,\n",
       "         -2.6152e-01, -9.0544e-01],\n",
       "        [-1.6054e+00,  1.1666e-01, -2.8253e-01,  ...,  1.0048e+00,\n",
       "         -9.6250e-01,  1.8574e+00],\n",
       "        [-1.0129e+00, -6.7703e-01,  3.7852e-01,  ...,  4.7502e-01,\n",
       "         -1.4882e+00,  1.0448e+00]])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standardize each feature: subtract its mean and divide by its standard\n",
    "# deviation (square root of the variance). The per-column statistics\n",
    "# broadcast across all rows of data.\n",
    "data_std = data_var.sqrt()\n",
    "data_normalized = (data - data_mean) / data_std\n",
    "data_normalized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([False, False, False,  ..., False, False, False])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply a score threshold: flag every wine whose quality is 3 or lower.\n",
    "# The result is a boolean tensor with one True/False entry per sample.\n",
    "bad_indexes = torch.le(target, 3)\n",
    "bad_indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(20))"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summing a boolean mask counts its True entries, i.e. the number of flagged wines.\n",
    "bad_indexes.shape, bad_indexes.dtype, bad_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 11])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Boolean-mask indexing keeps only the rows of data where bad_indexes is True.\n",
    "bad_data = data[bad_indexes]\n",
    "\n",
    "bad_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the samples into three quality bands using boolean masks on target:\n",
    "#   bad:  score <= 3\n",
    "#   mid:  3 < score < 7\n",
    "#   good: score >= 7\n",
    "bad_data = data[target <= 3]\n",
    "mid_data = data[(target > 3) & (target < 7)]\n",
    "good_data = data[target >= 7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature mean within each quality band (reduce over the samples, dim 0).\n",
    "bad_mean = bad_data.mean(dim=0)\n",
    "mid_mean = mid_data.mean(dim=0)\n",
    "good_mean = good_data.mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 0 fixed acidity          7.60   6.89   6.73\n",
      " 1 volatile acidity       0.33   0.28   0.27\n",
      " 2 citric acid            0.34   0.34   0.33\n",
      " 3 residual sugar         6.39   6.71   5.26\n",
      " 4 chlorides              0.05   0.05   0.04\n",
      " 5 free sulfur dioxide   53.33  35.42  34.55\n",
      " 6 total sulfur dioxide 170.60 141.83 125.25\n",
      " 7 density                0.99   0.99   0.99\n",
      " 8 pH                     3.19   3.18   3.22\n",
      " 9 sulphates              0.47   0.49   0.50\n",
      "10 alcohol               10.34  10.26  11.42\n"
     ]
    }
   ],
   "source": [
    "# Print one row per feature: index, column name, then the bad / mid / good means.\n",
    "# zip stops at the shortest input, so the extra trailing name in col_list\n",
    "# ('quality') is dropped because the mean tensors have no entry for it.\n",
    "rows = zip(col_list, bad_mean, mid_mean, good_mean)\n",
    "for i, (name, bad, mid, good) in enumerate(rows):\n",
    "    print('{:2} {:20} {:6.2f} {:6.2f} {:6.2f}'.format(i, name, bad, mid, good))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(2727))"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the mid-quality mean of 'total sulfur dioxide' as the decision threshold.\n",
    "total_sulfur_threshold = 141.83\n",
    "\n",
    "# Column 6 of data holds every sample's total sulfur dioxide value.\n",
    "total_sulfur_data = data[:, 6]\n",
    "\n",
    "# Predict True for wines whose total sulfur dioxide is below the threshold.\n",
    "predicted_indexes = total_sulfur_data < total_sulfur_threshold\n",
    "\n",
    "predicted_indexes.shape, predicted_indexes.dtype, predicted_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(3258))"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ground truth for the comparison: wines with a quality score above 5.\n",
    "actual_indexes = torch.gt(target, 5)\n",
    "\n",
    "actual_indexes.shape, actual_indexes.dtype, actual_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the threshold-based prediction against the actual labels.\n",
    "\n",
    "# Samples that are flagged by both the prediction and the ground truth.\n",
    "n_matches = (actual_indexes & predicted_indexes).sum().item()\n",
    "# Total samples flagged by the prediction, and by the ground truth.\n",
    "n_predicted = predicted_indexes.sum().item()\n",
    "n_actual = actual_indexes.sum().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2018, 0.74000733406674, 0.6193984039287906)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Match count, then the fraction of predictions that are correct (precision)\n",
    "# and the fraction of actual positives that were predicted (recall).\n",
    "n_matches, n_matches / n_predicted, n_matches / n_actual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
